#include "PCFG_v2.h"
#include <mpi.h>
#include <algorithm>
using namespace std;

void PriorityQueue::CalProb(PT &pt)
{
  // 计算PriorityQueue里面一个PT的流程如下：
  // 1. 首先需要计算一个PT本身的概率。例如，L6S1的概率为0.15
  // 2. 需要注意的是，Queue里面的PT不是“纯粹的”PT，而是除了最后一个segment以外，全部被value实例化的PT
  // 3. 所以，对于L6S1而言，其在Queue里面的实际PT可能是123456S1，其中“123456”为L6的一个具体value。
  // 4. 这个时候就需要计算123456在L6中出现的概率了。假设123456在所有L6 segment中的概率为0.1，那么123456S1的概率就是0.1*0.15

  // 计算一个PT本身的概率。后续所有具体segment value的概率，直接累乘在这个初始概率值上
  pt.prob = pt.preterm_prob;

  // index: 标注当前segment在PT中的位置
  int index = 0;

  for (int idx : pt.curr_indices)
  {
    // pt.content[index].PrintSeg();
    if (pt.content[index].type == 1)
    {
      // 下面这行代码的意义：
      // pt.content[index]：目前需要计算概率的segment
      // m.FindLetter(seg): 找到一个letter segment在模型中的对应下标
      // m.letters[m.FindLetter(seg)]：一个letter segment在模型中对应的所有统计数据
      // m.letters[m.FindLetter(seg)].ordered_values：一个letter segment在模型中，所有value的总数目
      pt.prob *= m.letters[m.FindLetter(pt.content[index])].ordered_freqs[idx];
      pt.prob /= m.letters[m.FindLetter(pt.content[index])].total_freq;
      // cout << m.letters[m.FindLetter(pt.content[index])].ordered_freqs[idx] << endl;
      // cout << m.letters[m.FindLetter(pt.content[index])].total_freq << endl;
    }
    if (pt.content[index].type == 2)
    {
      pt.prob *= m.digits[m.FindDigit(pt.content[index])].ordered_freqs[idx];
      pt.prob /= m.digits[m.FindDigit(pt.content[index])].total_freq;
      // cout << m.digits[m.FindDigit(pt.content[index])].ordered_freqs[idx] << endl;
      // cout << m.digits[m.FindDigit(pt.content[index])].total_freq << endl;
    }
    if (pt.content[index].type == 3)
    {
      pt.prob *= m.symbols[m.FindSymbol(pt.content[index])].ordered_freqs[idx];
      pt.prob /= m.symbols[m.FindSymbol(pt.content[index])].total_freq;
      // cout << m.symbols[m.FindSymbol(pt.content[index])].ordered_freqs[idx] << endl;
      // cout << m.symbols[m.FindSymbol(pt.content[index])].total_freq << endl;
    }
    index += 1;
  }
  // cout << pt.prob << endl;
}

void PriorityQueue::init()
{
  // cout << m.ordered_pts.size() << endl;
  // 用所有可能的PT，按概率降序填满整个优先队列
  for (PT pt : m.ordered_pts)
  {
    for (segment seg : pt.content)
    {
      if (seg.type == 1)
      {
        // 下面这行代码的意义：
        // max_indices用来表示PT中各个segment的可能数目。例如，L6S1中，假设模型统计到了100个L6，那么L6对应的最大下标就是99
        // （但由于后面采用了"<"的比较关系，所以其实max_indices[0]=100）
        // m.FindLetter(seg): 找到一个letter segment在模型中的对应下标
        // m.letters[m.FindLetter(seg)]：一个letter segment在模型中对应的所有统计数据
        // m.letters[m.FindLetter(seg)].ordered_values：一个letter segment在模型中，所有value的总数目
        pt.max_indices.emplace_back(m.letters[m.FindLetter(seg)].ordered_values.size());
      }
      if (seg.type == 2)
      {
        pt.max_indices.emplace_back(m.digits[m.FindDigit(seg)].ordered_values.size());
      }
      if (seg.type == 3)
      {
        pt.max_indices.emplace_back(m.symbols[m.FindSymbol(seg)].ordered_values.size());
      }
    }
    pt.preterm_prob = float(m.preterm_freq[m.FindPT(pt)]) / m.total_preterm;
    // pt.PrintPT();
    // cout << " " << m.preterm_freq[m.FindPT(pt)] << " " << m.total_preterm << " " << pt.preterm_prob << endl;

    // 计算当前pt的概率
    CalProb(pt);
    // 将PT放入优先队列
    priority.emplace_back(pt);
  }
  // cout << "priority size:" << priority.size() << endl;
}

void PriorityQueue::PopNext(double &time)
{

  // 对优先队列最前面的PT，首先利用这个PT生成一系列猜测
  Generate(priority.front(),time);

  // 然后需要根据即将出队的PT，生成一系列新的PT
  vector<PT> new_pts = priority.front().NewPTs();
  for (PT pt : new_pts)
  {
    // 计算概率
    CalProb(pt);
    // 接下来的这个循环，作用是根据概率，将新的PT插入到优先队列中
    for (auto iter = priority.begin(); iter != priority.end(); iter++)
    {
      // 对于非队首和队尾的特殊情况
      if (iter != priority.end() - 1 && iter != priority.begin())
      {
        // 判定概率
        if (pt.prob <= iter->prob && pt.prob > (iter + 1)->prob)
        {
          priority.emplace(iter + 1, pt);
          break;
        }
      }
      if (iter == priority.end() - 1)
      {
        priority.emplace_back(pt);
        break;
      }
      if (iter == priority.begin() && iter->prob < pt.prob)
      {
        priority.emplace(iter, pt);
        break;
      }
    }
  }

  // 现在队首的PT善后工作已经结束，将其出队（删除）
  priority.erase(priority.begin());
}

// 这个函数你就算看不懂，对并行算法的实现影响也不大
// 当然如果你想做一个基于多优先队列的并行算法，可能得稍微看一看了
vector<PT> PT::NewPTs()
{
  // 存储生成的新PT
  vector<PT> res;

  // 假如这个PT只有一个segment
  // 那么这个segment的所有value在出队前就已经被遍历完毕，并作为猜测输出
  // 因此，所有这个PT可能对应的口令猜测已经遍历完成，无需生成新的PT
  if (content.size() == 1)
  {
    return res;
  }
  else
  {
    // 最初的pivot值。我们将更改位置下标大于等于这个pivot值的segment的值（最后一个segment除外），并且一次只更改一个segment
    // 上面这句话里是不是有没看懂的地方？接着往下看你应该会更明白
    int init_pivot = pivot;

    // 开始遍历所有位置值大于等于init_pivot值的segment
    // 注意i < curr_indices.size() - 1，也就是除去了最后一个segment（这个segment的赋值预留给并行环节）
    for (int i = pivot; i < curr_indices.size() - 1; i += 1)
    {
      // curr_indices: 标记各segment目前的value在模型里对应的下标
      curr_indices[i] += 1;

      // max_indices：标记各segment在模型中一共有多少个value
      if (curr_indices[i] < max_indices[i])
      {
        // 更新pivot值
        pivot = i;
        res.emplace_back(*this);
      }

      // 这个步骤对于你理解pivot的作用、新PT生成的过程而言，至关重要
      curr_indices[i] -= 1;
    }
    pivot = init_pivot;
    return res;
  }

  return res;
}

// 这个函数是PCFG并行化算法的主要载体
// 尽量看懂，然后进行并行实现
void PriorityQueue::Generate(PT pt,double &time)
{
  // // 获取当前进程的rank和size
  // int rank, size;
  // MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  // MPI_Comm_size(MPI_COMM_WORLD, &size);

  // 计算PT的概率，这里主要是给PT的概率进行初始化
  CalProb(pt);

  // 对于只有一个segment的PT，直接遍历生成其中的所有value即可
  if (pt.content.size() == 1)
  {
    // 指向最后一个segment的指针，这个指针实际指向模型中的统计数据
    segment *a;
    // 在模型中定位到这个segment
    if (pt.content[0].type == 1)
    {
      a = &m.letters[m.FindLetter(pt.content[0])];
    }
    if (pt.content[0].type == 2)
    {
      a = &m.digits[m.FindDigit(pt.content[0])];
    }
    if (pt.content[0].type == 3)
    {
      a = &m.symbols[m.FindSymbol(pt.content[0])];
    }

    // Multi-thread TODO：
    // 这个for循环就是你需要进行并行化的主要部分了，特别是在多线程&GPU编程任务中
    // 可以看到，这个循环本质上就是把模型中一个segment的所有value，赋值到PT中，形成一系列新的猜测
    // 这个过程是可以高度并行化的
    // for (int i = 0; i < pt.max_indices[0]; i += 1)
    // {
    //     string guess = a->ordered_values[i];
    //     guesses.emplace_back(guess);
    //     total_guesses += 1;
    // }

    // MPI
    MPI(a->ordered_values, pt.max_indices[0], guesses, time);
  }
  else{
    string guess;
    int seg_idx = 0;
    // 这个for循环的作用：给当前PT的所有segment赋予实际的值（最后一个segment除外）
    // segment值根据curr_indices中对应的值加以确定
    // 这个for循环你看不懂也没太大问题，并行算法不涉及这里的加速
    for (int idx : pt.curr_indices)
    {
      if (pt.content[seg_idx].type == 1)
      {
        guess += m.letters[m.FindLetter(pt.content[seg_idx])].ordered_values[idx];
      }
      if (pt.content[seg_idx].type == 2)
      {
        guess += m.digits[m.FindDigit(pt.content[seg_idx])].ordered_values[idx];
      }
      if (pt.content[seg_idx].type == 3)
      {
        guess += m.symbols[m.FindSymbol(pt.content[seg_idx])].ordered_values[idx];
      }
      seg_idx += 1;
      if (seg_idx == pt.content.size() - 1)
      {
        break;
      }
    }

    // 指向最后一个segment的指针，这个指针实际指向模型中的统计数据
    segment *a;
    if (pt.content[pt.content.size() - 1].type == 1)
    {
      a = &m.letters[m.FindLetter(pt.content[pt.content.size() - 1])];
    }
    if (pt.content[pt.content.size() - 1].type == 2)
    {
      a = &m.digits[m.FindDigit(pt.content[pt.content.size() - 1])];
    }
    if (pt.content[pt.content.size() - 1].type == 3)
    {
      a = &m.symbols[m.FindSymbol(pt.content[pt.content.size() - 1])];
    }

    // Multi-thread TODO：
    // 这个for循环就是你需要进行并行化的主要部分了，特别是在多线程&GPU编程任务中
    // 可以看到，这个循环本质上就是把模型中一个segment的所有value，赋值到PT中，形成一系列新的猜测
    // 这个过程是可以高度并行化的
    // for (int i = 0; i < pt.max_indices[pt.content.size() - 1]; i += 1)
    // {
    //     string temp = guess + a->ordered_values[i];
    //     guesses.emplace_back(temp);
    //     total_guesses += 1;
    // }
    
    // MPI
    MPI(a->ordered_values, pt.max_indices[pt.content.size() - 1], guesses,time, guess);
  }
}

/***
 * MPI函数
 * 要把val放到guesses后面
 */
void PriorityQueue::MPI(std::vector<string> &val, int num, std::vector<string> &guesses, double &total_comm_time, std::string guess)
{
  // 获取当前进程的rank和size
  int rank, size;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &size);
  
  // 统计通信阻塞时间
  double local_comm_time = 0.0;

  // 划分任务
  int chunk = num / size;//每个进程分配的任务大小
  int remainder = num % size;//剩余任务大小
  int start = rank * chunk + min(rank, remainder);
  int end = start + chunk + (rank < remainder ? 1 : 0);
  
  int sub_num = end - start;
  int local_total_len = 0;
  std::vector<int> local_lengths(sub_num);

  //记录偏移量，这样可以并行拼接
  std::vector<int> local_offsets(sub_num);

  //先检查自己的任务所有string的长度加起来有多少,同时记录每个的偏移量
  if(guess!=""){
    int len = guess.length();
    for (int i = start; i < end; i++)
    {
      local_offsets[i-start] =local_total_len;
      local_total_len += val[i].length() + len;
      local_lengths[i-start] = len + val[i].length();
    }
  }
  else{
    for (int i = start; i < end; i++)
    {
      local_offsets[i - start] = local_total_len;
      local_total_len += val[i].length();
      local_lengths[i - start] = val[i].length();
    }
  }
  
  std::vector<char> local_chars(local_total_len);
  
  if(guess!=""){
#pragma omp parallel for
    for (int i = start; i < end; ++i)
    {
      int idx = i - start;
      std::string s = guess + val[i];
      std::copy(
          s.begin(),
          s.end(),
          local_chars.begin() + local_offsets[idx]);
    }
  }
  else{
#pragma omp parallel for
    for (int i = start; i < end; ++i)
    {
      int idx = i - start;
      std::copy(
          val[i].begin(),
          val[i].end(),
          local_chars.begin() + local_offsets[idx]);
    }
  }

  // 开始收集结果
  // 1. 收集每个进程的猜测数量
  std::vector<int> counts(size);
  double comm_start = MPI_Wtime();
  MPI_Gather(&sub_num, 1, MPI_INT, counts.data(), 1, MPI_INT, 0, MPI_COMM_WORLD);
  double comm_end = MPI_Wtime();
  local_comm_time += (comm_end - comm_start);

  // 2. 收集每个进程的每个猜测长度
  std::vector<int> all_lengths; // 用于存储所有进程的猜测长度
  std::vector<int> displs_lengths(size, 0);   // 每个进程的猜测数量偏移量
  if (rank == 0)
  {
    int total = 0;
    for (int i = 0; i < size; ++i)
    {
      displs_lengths[i] = total;
      total += counts[i];
    }
    all_lengths.resize(total);
  }
  
  comm_start = MPI_Wtime();
  MPI_Gatherv(sub_num > 0 ? local_lengths.data() : NULL, sub_num, MPI_INT,
              all_lengths.data(), counts.data(), displs_lengths.data(),
              MPI_INT, 0, MPI_COMM_WORLD);
  comm_end = MPI_Wtime();
  local_comm_time += (comm_end - comm_start);

  // 3. 收集每个进程总字符的长度
  int local_chars_size = local_chars.size();
  std::vector<int> chars_counts(size);
  
  comm_start = MPI_Wtime();
  MPI_Gather(&local_chars_size, 1, MPI_INT, chars_counts.data(), 1, MPI_INT, 0, MPI_COMM_WORLD);
  comm_end = MPI_Wtime();
  local_comm_time += (comm_end - comm_start);

  // 4. 收集总字符串
  std::vector<int> displs_chars(size, 0);
  std::vector<char> all_chars;
  if (rank == 0)
  {
    int total = 0;
    for (int i = 0; i < size; ++i)
    {
      displs_chars[i] = total;
      total += chars_counts[i];
    }
    all_chars.resize(total);
  }
  
  comm_start = MPI_Wtime();
  MPI_Gatherv(local_chars_size > 0 ? local_chars.data() : NULL, local_chars_size, MPI_CHAR,
              all_chars.data(), chars_counts.data(), displs_chars.data(),
              MPI_CHAR, 0, MPI_COMM_WORLD);
  comm_end = MPI_Wtime();
  local_comm_time += (comm_end - comm_start);

  // 记录此次操作中的通信时间
  total_comm_time = 0.0;
  total_comm_time += local_comm_time;

  // 5. 插回到guesses中
  if (rank == 0)
  {
    size_t old_size = guesses.size();
    size_t new_size = old_size + all_lengths.size();
    guesses.resize(new_size);

    // 先算所有偏移量
    std::vector<size_t> all_offsets(all_lengths.size(), 0);
    size_t acc = 0;
    size_t a_l_size = all_lengths.size();
    for (size_t i = 0; i < a_l_size; ++i)
    {
      all_offsets[i] = acc;
      acc += all_lengths[i];
    }

// 并行生成
#pragma omp parallel for schedule(static)
    for (int i = 0; i < static_cast<int>(all_lengths.size()); ++i)
    {
      guesses[old_size + i] = std::string(
          all_chars.data() + all_offsets[i],
          all_lengths[i]);
    }
  }
}